import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from scipy.interpolate import griddata
from scipy.stats import gaussian_kde


class HPOVisualizer:
    """Plotting helpers for hyperparameter-optimization (HPO) trial histories.

    Every trial is flattened into one row of ``self.df``: one column per
    hyperparameter found in the trial's ``config`` plus a ``"score"`` column.
    """

    def __init__(self, trial_history):
        """Build the trial table.

        Args:
            trial_history: Iterable of records, each a mapping with a
                ``"config"`` mapping (hyperparameter name -> value) and a
                ``"score"`` entry. A config key literally named ``"score"``
                is overwritten by the record's score.
        """
        # Merge into a fresh dict so the caller's config objects are never mutated.
        records = [{**record["config"], "score": record["score"]} for record in trial_history]
        self.df = pd.DataFrame(records)

    @staticmethod
    def _point_density(x, y):
        """Return a 2-D Gaussian KDE estimate at each (x, y) point, or None.

        None is returned when a KDE cannot be fitted: fewer than three points,
        non-finite values, or degenerate data (e.g. all scores identical) that
        makes the covariance matrix singular.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.size < 3 or not (np.isfinite(x).all() and np.isfinite(y).all()):
            return None
        xy = np.vstack([x, y])
        try:
            return gaussian_kde(xy)(xy)
        except (ValueError, np.linalg.LinAlgError):
            return None

    def plot_optimization_history(self, show_line=True, show_density=True, cmap="viridis"):
        """Plot the objective score of every trial in trial order.

        Args:
            show_line: Connect the points with a line. Only honoured when no
                density colouring is drawn.
            show_density: Colour points by a 2-D KDE of the (trial index,
                score) pairs. Falls back to the plain plot when the KDE cannot
                be fitted (too few points, non-finite or degenerate scores).
            cmap: Matplotlib colormap name used for the density colouring.
        """
        plt.figure(figsize=(8, 4))
        x = self.df.index.values
        y = self.df["score"].values

        density = self._point_density(x, y) if show_density else None
        if density is not None:
            # Draw the densest points last so they stay visible on top.
            order = density.argsort()
            scatter = plt.scatter(
                x[order], y[order], c=density[order], s=40, cmap=cmap, edgecolor="none"
            )
            cbar = plt.colorbar(scatter)
            cbar.set_label("Point Density")
        elif show_line:
            plt.plot(x, y, marker="o")
        else:
            plt.scatter(x, y, s=40)

        plt.xlabel("Trial")
        plt.ylabel("Objective Score")
        plt.title("Optimization History")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    def plot_contour(self, param_x, param_y):
        """Draw a filled, labelled contour plot of the score over two hyperparameters.

        Scores of trials sharing the same (param_x, param_y) pair are
        aggregated by ``pivot_table`` (mean by default). Failures are reported
        with a printed message, for example too few distinct values or
        non-numeric parameters. The partially built figure is then closed.
        """
        fig = plt.figure(figsize=(8, 6))
        try:
            pivot = self.df.pivot_table(index=param_y, columns=param_x, values="score")
            X, Y = np.meshgrid(pivot.columns, pivot.index)
            Z = pivot.values
            levels = 20
            contour_filled = plt.contourf(X, Y, Z, levels=levels, cmap="Spectral", alpha=0.85)
            contour_lines = plt.contour(X, Y, Z, levels=levels, colors="black", linewidths=0.5)
            plt.clabel(contour_lines, inline=True, fontsize=8, fmt="%.2f")

            cbar = plt.colorbar(contour_filled)
            cbar.set_label("Objective Score")

            plt.xlabel(param_x)
            plt.ylabel(param_y)
            plt.title(f"Enhanced Contour Plot: {param_x} vs {param_y}")

            plt.grid(visible=False)
            plt.tight_layout()
            plt.show()

        except Exception as e:
            # Best-effort plot: discard the half-built figure so a later
            # plt.show() does not display an empty/broken window.
            plt.close(fig)
            print(f"Cannot plot contour for {param_x} vs {param_y}: {e}")

    def plot_slice(self, param):
        """Scatter the objective score against a single hyperparameter."""
        plt.figure(figsize=(6, 4))
        sns.scatterplot(x=self.df[param], y=self.df["score"])
        plt.xlabel(param)
        plt.ylabel("Objective Score")
        plt.title(f"Slice Plot: {param}")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    def plot_surface(self, param_x, param_y, show_points=True, z_range_truncate=False, z_range=(0, 1)):
        """3D surface plot of the score over two hyperparameters.

        The scattered trials are interpolated onto a 50x50 grid with cubic
        ``griddata``. Grid cells outside the convex hull of the trials are NaN.
        Failures are reported with a printed message, for example too few or
        collinear points for cubic interpolation. The figure is then closed.

        Args:
            param_x: Column name plotted on the X axis.
            param_y: Column name plotted on the Y axis.
            show_points: Overlay the raw trials as black dots.
            z_range_truncate: Clip the interpolated surface to ``z_range``.
            z_range: ``(min, max)`` used for clipping; ignored when None or
                when ``z_range_truncate`` is False. Default (0, 1).
        """
        fig = None
        try:
            X = self.df[param_x].values
            Y = self.df[param_y].values
            Z = self.df["score"].values

            xi = np.linspace(X.min(), X.max(), 50)
            yi = np.linspace(Y.min(), Y.max(), 50)
            Xi, Yi = np.meshgrid(xi, yi)
            Zi = griddata((X, Y), Z, (Xi, Yi), method="cubic")

            if z_range is not None and z_range_truncate:
                Zi = np.clip(Zi, z_range[0], z_range[1])

            fig = plt.figure(figsize=(10, 7))
            ax = fig.add_subplot(111, projection="3d")
            surf = ax.plot_surface(Xi, Yi, Zi, cmap="Spectral", alpha=0.9, edgecolor="none")

            if show_points:
                ax.scatter(X, Y, Z, color="black", s=20)

            fig.colorbar(surf, shrink=0.5, aspect=5, label="Objective Score")

            ax.set_xlabel(param_x)
            ax.set_ylabel(param_y)
            ax.set_zlabel("Objective Score")
            ax.set_title(f"3D Surface Plot: {param_x} vs {param_y} ")

            plt.tight_layout()
            plt.show()

        except Exception as e:
            # Best-effort plot: do not leave a half-built figure open.
            if fig is not None:
                plt.close(fig)
            print(f"Cannot plot 3D surface for {param_x} vs {param_y}: {e}")

    def plot_regret_curve(self, regret_curve):
        """Plot a per-trial regret sequence; warn and skip when it is None or empty."""
        if regret_curve is None or len(regret_curve) == 0:
            print("[Warning] Cannot plot regret curve: best_known_optimum not provided or regret data unavailable.")
            return

        plt.figure(figsize=(8, 4))
        plt.plot(range(len(regret_curve)), regret_curve)
        plt.xlabel("Trial")
        plt.ylabel("Regret")
        plt.title("Regret Curve")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    def _resolve_plot_params(self, param_x, param_y):
        """Fill in whichever of ``param_x`` / ``param_y`` is None.

        Explicitly supplied names are kept. Missing ones are taken, in column
        order, from hyperparameter columns not already chosen. A name stays
        None when no unused hyperparameter column is left.
        """
        remaining = [c for c in self.df.columns if c != "score" and c not in (param_x, param_y)]
        if param_x is None and remaining:
            param_x = remaining.pop(0)
        if param_y is None and remaining:
            param_y = remaining.pop(0)
        return param_x, param_y

    def summary(self, param_x=None, param_y=None, z_range_truncate=True, z_range=(0, 1), regret_curve=None):
        """
        Show a summary of the optimization results, including:
        - Optimization history plot
        - Slice plot for ``param_x`` (needs at least one hyperparameter)
        - Surface plot for ``param_x`` vs ``param_y`` (needs two hyperparameters)
        - Regret curve (when ``regret_curve`` is provided)

        Parameters:
            param_x (str, optional): First hyperparameter for the slice/3D plots.
                Defaults to the first hyperparameter column not already chosen.
            param_y (str, optional): Second hyperparameter for the 3D plot.
                Defaults to the next hyperparameter column not already chosen.
            z_range_truncate (bool, optional): Whether to clip the surface plot's Z
                values (objective scores) to ``z_range``. Default is True.
            z_range (tuple, optional): (min, max) range for Z clipping. Default (0, 1).
            regret_curve (sequence, optional): Per-trial regret values to plot.

        Returns:
            self, to allow call chaining.
        """
        print("\n📊 HPO Visualizer Summary and Plots")
        print("=" * 40)
        if self.df.empty:
            print("No trial history available.")
            return self
        # 1. Optimization history
        self.plot_optimization_history(show_line=False)
        # Keep any caller-supplied names; fill only the missing ones.
        param_x, param_y = self._resolve_plot_params(param_x, param_y)
        # 2. Slice plot (needs a single hyperparameter)
        if param_x is not None:
            self.plot_slice(param_x)
        # 3. Surface plot (needs two hyperparameters)
        if param_x is not None and param_y is not None:
            self.plot_surface(param_x, param_y, show_points=True, z_range_truncate=z_range_truncate, z_range=z_range)
        else:
            print("Not enough hyperparameters for 2D plots.")
        # 4. Regret curve (independent of the hyperparameters)
        self.plot_regret_curve(regret_curve=regret_curve)
        return self
